import argparse
from engine import *


def get_args():
    parser = argparse.ArgumentParser(description='Train the Net',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--ce', action='store_true', help='CETrainer')
    parser.add_argument('--ce3d', action='store_true', help='CE3DTrainer')
    parser.add_argument('--md', action='store_true', help='MDTrainer')
    parser.add_argument('--md3d', action='store_true', help='MD3DTrainer')

    return parser.parse_args()


if __name__ == '__main__':
    opt = get_args()
    if opt.ce == True:
        trainer = CETrainer()
    elif opt.ce3d == True:
        trainer = CE3DTrainer()
    elif opt.md == True:
        trainer = MDTrainer()
    elif opt.md3d == True:
        trainer = MD3DTrainer()
    else:
        ValueError('需要传参选择要训练的模型！')

    trainer.train()
    del trainer